/***************************************************************************************************
 * Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
 * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Epilogue for threadblock scoped GEMMs using Tensor Ops.

  The epilogue rearranges the result of a matrix product through shared memory
  to match canonical tensor layouts in global memory. Epilogues support
  conversion and reduction operations.

*/

#pragma once

#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif

#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/functional.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/vector.h"
#include "cutlass/numeric_types.h"  // NOLINT
#include "cutlass/tensor_coord.h"

#include "cutlass/gemm/gemm.h"

#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"

#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"

namespace cutlass {
namespace epilogue {
namespace threadblock {

// Dual epilogue operator: applies two independent output operators to two
// accumulator tiles, plus a third operator that combines their results,
// producing up to three output tensors (D0, D1, D2).
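//
// A minimal usage sketch (illustrative only: the variable names below are
// hypothetical placeholders, not symbols defined in this header). Each call
// stages the two accumulator tiles through shared memory, applies the three
// output operators, and streams up to three tensors to global memory:
//
//   __shared__ DualEpilogue::SharedStorage storage;
//   DualEpilogue epilogue(storage, thread_idx, warp_idx, lane_idx);
//   epilogue(op0, op1, op2,        // elementwise output operators
//            dest0, dest1, dest2,  // destination iterators for D0, D1, D2
//            accum0, accum1,       // warp-level accumulator tiles
//            sources,              // source iterators for C0 and C1
//            is_final_k_slice);    // write D2 only on the final slice
//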
template <typename Shape_,  // Shape of threadblock tile (concept: GemmShape)
          typename WarpMmaOperator_,  // Warp-level MMA operator (concept:
                                      // gemm::warp::MmaTensorOp)
          int PartitionsK,            // Number of partitions of the K dimension
          typename OutputTileIterator_,   // Tile iterator reading and writing
                                          // output tensors
          typename OutputTileIterator2_,  // Tile iterator reading and writing
                                          // output tensors
          typename AccumulatorFragmentIterator_,  // Fragment iterator
                                                  // selecting accumulators
          typename WarpTileIterator_,    // Warp-scoped tile iterator writing
                                         // accumulators to SMEM
          typename SharedLoadIterator_,  // Threadblock-scoped tile iterator
                                         // loading from SMEM
                                         // Output operator
          typename OutputOp0_,
          typename OutputOp1_,
          typename OutputOp2_,
          typename Padding_,  // Padding added to SMEM allocation to avoid
                              // bank conflicts (concept: MatrixShape)
          bool StoreD0 = true,
          bool StoreD1 = true,
          int FragmentsPerPartition =
              1,                  // Used to coarsen the epilogue granularity
          int IterationsUnroll =  // Used to reduce binary size when epilogue
                                  // op is large
          (!IsEpilogueFunctorHeavy<OutputOp0_>::value)>
class DualEpilogue {
    public:
    using Base = EpilogueBase<Shape_,
                              typename WarpMmaOperator_::Shape,
                              PartitionsK,
                              AccumulatorFragmentIterator_,
                              WarpTileIterator_,
                              Padding_,
                              FragmentsPerPartition>;

    using Shape = Shape_;
    using WarpMmaOperator = WarpMmaOperator_;
    static int const kPartitionsK = PartitionsK;
    static bool constexpr kStoreD0 = StoreD0;
    static bool constexpr kStoreD1 = StoreD1;
    using OutputTileIterator = OutputTileIterator_;
    using OutputTileIterator2 = OutputTileIterator2_;
    using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
    using WarpTileIterator = WarpTileIterator_;
    using SharedLoadIterator = SharedLoadIterator_;
    using OutputOp0 = OutputOp0_;
    using OutputOp1 = OutputOp1_;
    using OutputOp2 = OutputOp2_;
    using Padding = Padding_;

    using Layout = layout::RowMajor;
    using LongIndex = typename Layout::LongIndex;

    // The complete warp-level accumulator tile
    using AccumulatorTile = typename Base::AccumulatorTile;

    // Accumulator element
    using ElementAccumulator = typename WarpTileIterator::Element;

    // Output element
    using ElementOutput = typename OutputTileIterator::Element;

    // Output access size
    static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;

    // Tensor reference to destination tensor
    using TensorRef = typename OutputTileIterator::TensorRef;

    // Tensor reference to sync tensor
    using SyncTensorRef =
        typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;

    // Const tensor reference to source tensor
    using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;

    // Array type used for the D0 and D1 outputs
    using OutputAccessType = Array<typename OutputTileIterator::Element,
                                   OutputTileIterator::kElementsPerAccess>;

    // Array type used for the final (D2) output
    using OutputAccessType2 = Array<typename OutputTileIterator2::Element,
                                    OutputTileIterator2::kElementsPerAccess>;

    // Array type used by output functor
    using AccumulatorAccessType = Array<typename WarpTileIterator::Element,
                                        OutputTileIterator::kElementsPerAccess>;

    // Number of warps
    using WarpCount = typename Base::WarpCount;

    struct SharedStorage {
        using Element = typename WarpTileIterator::Element;

        // Tensor reference to shared memory allocation
        using TensorRef = typename WarpTileIterator::TensorRef;

        // Logical shape of the shared memory tile written to by all warps.
        using Shape = typename Base::Shape;

        // Shape of the shared memory allocation for the epilogue
        using StorageShape = typename Base::SharedStorage::StorageShape;

        //
        // Data members
        //

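        // Two staging buffers: one shared memory tile per accumulator path
        // (one for D0, one for D1).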
        AlignedBuffer<Element, StorageShape::kCount> storage[2];

        //
        // Methods
        //

        // Returns a tensor reference to the shared memory buffer
        CUTLASS_DEVICE
        TensorRef reference(int i) {
            return TensorRef(
                storage[i].data(),
                Layout::packed({StorageShape::kRow, StorageShape::kColumn}));
        }
    };

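    // The shared memory allocation is partitioned into kSmemTiles slices
    // (one per k-partition, or one per coarsened fragment), and
    // kSmemPointerOffset is the element stride between consecutive slices.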
    static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1
                                          ? Base::kFragmentsPerIteration
                                          : kPartitionsK;
    static int constexpr kSmemPointerOffset =
        SharedStorage::StorageShape::kCount / kSmemTiles;

    public:
    static_assert(
        SharedLoadIterator::Fragment::kElements ==
            OutputTileIterator::Fragment::kElements,
        "Mismatch between shared load iterator and output tile iterator.");

    static_assert(OutputTileIterator::kElementsPerAccess,
                  "OutputTileIterator::kElementsPerAccess must not be zero.");

    static_assert(!(OutputTileIterator::Fragment::kElements %
                    OutputTileIterator::kElementsPerAccess),
                  "Fragment size must be divisible by the access width.");

    private:
    // Loads fragment from shared memory aligned with output tensor
    SharedLoadIterator shared_load_iterator0_;
    SharedLoadIterator shared_load_iterator1_;

    // Stores a warp's fragment of accumulators to SMEM
    WarpTileIterator warp_tile_iterator0_;
    WarpTileIterator warp_tile_iterator1_;

    public:
    // Constructor
    CUTLASS_DEVICE
    DualEpilogue(
        SharedStorage &shared_storage,  // Shared storage object  // NOLINT
        int thread_idx,                 // ID of a thread within the threadblock
        int warp_idx,                   // ID of warp within threadblock
        int lane_idx                    // ID of thread within warp
        )
        : shared_load_iterator0_(shared_storage.reference(0), thread_idx),
          shared_load_iterator1_(shared_storage.reference(1), thread_idx),
          warp_tile_iterator0_(shared_storage.reference(0), lane_idx),
          warp_tile_iterator1_(shared_storage.reference(1), lane_idx) {
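        // Decompose the linear warp index into (m, n, k) coordinates within
        // the threadblock tile (m varies fastest); each k-slice is stacked
        // along the row dimension of the shared memory tile.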
        int warp_k = warp_idx / (WarpCount::kM * WarpCount::kN);
        int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN);
        int warp_m = warp_mn % WarpCount::kM;
        int warp_n = warp_mn / WarpCount::kM;

        MatrixCoord warp_offset{warp_k * WarpCount::kM + warp_m, warp_n};

        warp_tile_iterator0_.add_tile_offset(warp_offset);
        warp_tile_iterator1_.add_tile_offset(warp_offset);
    }

    // Streams the result to global memory
    CUTLASS_DEVICE
    void operator()(OutputOp0 const &output_op0,
                    OutputOp1 const &output_op1,
                    OutputOp2 const &output_op2,
                    OutputTileIterator dest0,
                    OutputTileIterator dest1,
                    OutputTileIterator2 dest2,
                    AccumulatorTile const &accumulator0,
                    AccumulatorTile const &accumulator1,
                    OutputTileIterator source_iterator[2],
                    bool writeToD2  // true if this is the final split-K slice
    ) {
      // TODO: Implement when no source is needed  // NOLINT
      typename OutputTileIterator::Fragment source_fragment[2];
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < 2; ++i) {
          source_fragment[i].clear();
      }

      //
      // Iterator over warp-level accumulator fragment
      //

      AccumulatorFragmentIterator accum_fragment_iterator[2] = {accumulator0,
                                                                accumulator1};

      //
      // Iterate over accumulator tile
      //

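      // Fully unroll the tile loop only when the output operator is light
      // (IterationsUnroll != 0); otherwise leave it rolled to limit binary
      // size.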
      #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
      for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
          //
          // Load the source
          //

          CUTLASS_PRAGMA_UNROLL
          for (int i = 0; i < 2; ++i) {
              source_iterator[i].load(source_fragment[i]);
              ++source_iterator[i];
          }

          //
          // Convert and store fragment
          //

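          // Make sure all reads of the shared memory staging buffers from the
          // previous iteration have finished before warps overwrite them.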
          __syncthreads();

          acc2smem_source_needed<cutlass::make_index_sequence<
              OutputTileIterator::kIterations>>::push(iter,
                                                      accum_fragment_iterator[0],
                                                      this->warp_tile_iterator0_);
          acc2smem_source_needed<cutlass::make_index_sequence<
              OutputTileIterator::kIterations>>::push(iter,
                                                      accum_fragment_iterator[1],
                                                      this->warp_tile_iterator1_);

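          // Make sure every warp has written its accumulator fragment to
          // shared memory before any thread loads it back.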
          __syncthreads();

          //
          // Load fragments from shared memory
          //

          typename SharedLoadIterator::Fragment
              aligned_accum_fragment0[kPartitionsK];
          typename SharedLoadIterator::Fragment
              aligned_accum_fragment1[kPartitionsK];

          shared_load_iterator0_.load(aligned_accum_fragment0[0]);
          shared_load_iterator1_.load(aligned_accum_fragment1[0]);

          // If the number of k-slices is > 1, reduce the partial sums across
          // the k-slices.
          if (kPartitionsK > 1) {
            plus<typename SharedLoadIterator::Fragment> add_fragments;

            CUTLASS_PRAGMA_UNROLL
            for (int i = 1; i < kPartitionsK; ++i) {
              shared_load_iterator0_.add_pointer_offset(kSmemPointerOffset);
              shared_load_iterator1_.add_pointer_offset(kSmemPointerOffset);
              shared_load_iterator0_.load(aligned_accum_fragment0[i]);
              shared_load_iterator1_.load(aligned_accum_fragment1[i]);
              aligned_accum_fragment0[0] = add_fragments(
                  aligned_accum_fragment0[0], aligned_accum_fragment0[i]);
              aligned_accum_fragment1[0] = add_fragments(
                  aligned_accum_fragment1[0], aligned_accum_fragment1[i]);
            }

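            // Rewind both shared-load pointers to the first k-slice so the
            // next iteration starts from a clean offset.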
            shared_load_iterator0_.add_pointer_offset((1 - kPartitionsK) *
                                                      kSmemPointerOffset);
            shared_load_iterator1_.add_pointer_offset((1 - kPartitionsK) *
                                                      kSmemPointerOffset);
          }

          //
          // Compute the output result
          //

          typename OutputTileIterator::Fragment output_fragment[2];
          typename OutputTileIterator2::Fragment output_fragment_final;

          apply_output_operator_(output_fragment,
                                output_fragment_final,
                                output_op0,
                                output_op1,
                                output_op2,
                                aligned_accum_fragment0[0],
                                aligned_accum_fragment1[0],
                                source_fragment);

          //
          // Store the final result
          //

          if (kStoreD0) {
              dest0.store(output_fragment[0]);
              ++dest0;
          }
          if (kStoreD1) {
              dest1.store(output_fragment[1]);
              ++dest1;
          }
          if (writeToD2) {
              dest2.store(output_fragment_final);
              ++dest2;
          }
      }
    }

    private:
    static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1,
                  "Either kPartitionsK or kFragmentsPerIteration must be 1.");

    template <class Seq>
    struct acc2smem_source_needed;

    template <size_t... Seq>
    struct acc2smem_source_needed<cutlass::index_sequence<Seq...>> {
        template <int Advance>
        CUTLASS_DEVICE static void helper(
            AccumulatorFragmentIterator accum_fragment_iterator,
            WarpTileIterator &warp_tile_iterator) {  // NOLINT
            CUTLASS_PRAGMA_UNROLL
            for (int i = 0; i < Advance; i++) {
              ++accum_fragment_iterator;
            }

            typename AccumulatorFragmentIterator::Fragment accum_fragment;
            accum_fragment_iterator.load(accum_fragment);
            warp_tile_iterator.store(accum_fragment);
        }

        CUTLASS_DEVICE
        static void push(size_t pos,
                        AccumulatorFragmentIterator const &iterator_begin,
                        WarpTileIterator &warp_tile_iterator) {  // NOLINT
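            // Pack expansion over Seq produces one guarded call per
            // iteration index; short-circuiting ensures that exactly the
            // helper<Seq> whose compile-time Seq equals the runtime pos runs.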
            int dummy[] = {(pos == Seq) &&
                          (helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};
            (void)dummy;  // Suppress unused-variable warnings.
        }
    };

    // Helper to invoke the output functor over each vector of output
    CUTLASS_DEVICE
    void apply_output_operator_(
        typename OutputTileIterator::Fragment (&output_fragment)[2],
        typename OutputTileIterator2::Fragment &output_fragment_final,  // NOLINT
        OutputOp0 const &output_op0,
        OutputOp1 const &output_op1,
        OutputOp2 const &output_op2,
        typename SharedLoadIterator::Fragment const &aligned_accum_fragment0,
        typename SharedLoadIterator::Fragment const &aligned_accum_fragment1,
        typename OutputTileIterator::Fragment const (&source_fragment)[2]) {
        OutputAccessType *output_frag_ptr[2] = {
            reinterpret_cast<OutputAccessType *>(&output_fragment[0]),
            reinterpret_cast<OutputAccessType *>(&output_fragment[1])};

        OutputAccessType2 *output_frag_final_ptr =
            reinterpret_cast<OutputAccessType2 *>(&output_fragment_final);

        AccumulatorAccessType const *compute_frag_ptr[2] = {
            reinterpret_cast<AccumulatorAccessType const *>(
                &aligned_accum_fragment0),
            reinterpret_cast<AccumulatorAccessType const *>(
                &aligned_accum_fragment1)};

        OutputAccessType const *source_frag_ptr[2] = {
            reinterpret_cast<OutputAccessType const *>(&source_fragment[0]),
            reinterpret_cast<OutputAccessType const *>(&source_fragment[1])};

        int const kOutputOpIterations = OutputTileIterator::Fragment::kElements /
                                        OutputTileIterator::kElementsPerAccess;

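        // Per output vector: D0 = op0(Acc0, C0), D1 = op1(Acc1, C1), and
        // D2 = op2(D0, D1).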
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < kOutputOpIterations; ++i) {
            // Call the output operators
            output_frag_ptr[0][i] =
                output_op0(compute_frag_ptr[0][i], source_frag_ptr[0][i]);
            output_frag_ptr[1][i] =
                output_op1(compute_frag_ptr[1][i], source_frag_ptr[1][i]);
            output_frag_final_ptr[i] =
                output_op2(output_frag_ptr[0][i], output_frag_ptr[1][i]);
        }
    }
};

}  // namespace threadblock
}  // namespace epilogue
}  // namespace cutlass
